import argparse
from pathlib import Path
import time
from glob import glob
import os
import shutil
from tracemalloc import start

import torch
import wandb  # Quit early if user doesn't have wandb installed.
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

# from dalle_pytorch import OpenAIDiscreteVAE, VQGanVAE, DiscreteVAE, DALLE
from dalle_pytorch.dalle_pytorch_ori import OpenAIDiscreteVAE, VQGanVAE, DiscreteVAE, DALLE_PG_Discrete, DiscretePGVAE
# from dalle_pytorch.dalle_pytorch import OpenAIDiscreteVAE, VQGanVAE, DiscreteVAE, DALLE
from dalle_pytorch import distributed_utils
from dalle_pytorch.loader import TextImageDataset, TextPtsDataset
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, ChineseTokenizer, YttmTokenizer

# libraries needed for webdataset support

import webdataset as wds
from torchvision import transforms as T
from PIL import Image
from io import BytesIO

from shape2prog.dataset import Synthesis3D

from IPython import embed
# argument parsing

import h5py
import numpy as np
from torch.utils.data import Dataset

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
import comm as comm
from tensorboardX import SummaryWriter
import time
from datetime import timedelta

def normalize_points_torch(points):
    """Center each point cloud on its centroid and scale it into the unit ball.

    Args:
        points (torch.Tensor): (batch_size, num_points, 3)

    Returns:
        torch.Tensor: normalized points with the same shape as ``points``.
        Each cloud has zero mean and its farthest point at distance 1 from
        the origin. A cloud whose centered points are all exactly zero
        (every point identical) is returned as zeros instead of NaN.

    Raises:
        AssertionError: if ``points`` is not a rank-3 tensor with last dim 3.
    """
    assert points.dim() == 3 and points.size(2) == 3
    centroid = points.mean(dim=1, keepdim=True)
    points = points - centroid
    # (batch_size, 1, 1): distance of the farthest point from the centroid.
    norm, _ = points.norm(dim=2, keepdim=True).max(dim=1, keepdim=True)
    # A zero radius would give 0/0 -> NaN; the centered cloud is already all
    # zeros in that case, so dividing by 1 leaves it unchanged.
    norm = torch.where(norm > 0, norm, torch.ones_like(norm))
    return points / norm

def setup_ddp(gpu, args):
    """Join the NCCL process group as rank ``gpu`` and bind this process to that GPU.

    Rendezvous uses ``init_method='env://'`` (MASTER_ADDR / MASTER_PORT are
    read from the environment), and ``args.world_size`` gives the total number
    of processes. Every process seeds torch's RNG with 0.
    """
    group_options = {
        'backend': 'nccl',  # alternative kept from the original: 'gloo'
        'init_method': 'env://',
        'world_size': args.world_size,
        'rank': gpu,
    }
    dist.init_process_group(**group_options)

    torch.manual_seed(0)
    torch.cuda.set_device(gpu)


def train(rank, args):
    """Train a ``DALLE_PG_Discrete`` model; entry point for one training process.

    Called directly (single process) when ``args.gpus == 1``, or once per GPU
    through ``mp.spawn`` when ``args.gpus > 1``. In the multi-GPU case ``rank``
    is the process/GPU index, the process group is set up by ``setup_ddp`` and
    the model is wrapped in ``DistributedDataParallel``.

    Args:
        rank (int): process / GPU index; only read when ``args.gpus > 1``.
        args (argparse.Namespace): parsed command-line options.

    Raises:
        ValueError: if ``--vae_path`` was not given.
        FileNotFoundError: if the VAE checkpoint does not exist.
        IsADirectoryError: if the VAE checkpoint path is a directory.

    The main process writes a checkpoint to
    ``args.dalle_output_file_name + args.save_name + '.pt'`` every
    ``args.save_every_n_steps`` iterations, after every epoch, and once more
    when training ends.
    """
    if args.gpus > 1:
        setup_ddp(rank, args)

    def exists(val):
        return val is not None

    def get_trainable_params(model):
        return [params for params in model.parameters() if params.requires_grad]

    def get_pkg_version():
        from pkg_resources import get_distribution
        return get_distribution('dalle_pytorch').version

    def cp_path_to_dir(cp_path, tag):
        """Convert a checkpoint path to a directory with `tag` inserted.
        If `cp_path` is already a directory, return it unchanged.
        """
        if not isinstance(cp_path, Path):
            cp_path = Path(cp_path)
        if cp_path.is_dir():
            return cp_path
        path_sans_extension = cp_path.parent / cp_path.stem
        cp_dir = Path(f'{path_sans_extension}-{tag}-cp')
        return cp_dir

    # constants

    WEBDATASET_IMAGE_TEXT_COLUMNS = tuple(args.wds.split(','))
    ENABLE_WEBDATASET = len(WEBDATASET_IMAGE_TEXT_COLUMNS) == 2

    DALLE_OUTPUT_FILE_NAME = args.dalle_output_file_name + args.save_name + ".pt"

    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size

    LEARNING_RATE = args.learning_rate
    # NOTE(review): GRAD_CLIP_NORM, LR_DECAY and KEEP_N_CHECKPOINTS are read
    # here but not applied anywhere below.
    GRAD_CLIP_NORM = args.clip_grad_norm
    LR_DECAY = args.lr_decay
    SAVE_EVERY_N_STEPS = args.save_every_n_steps
    KEEP_N_CHECKPOINTS = args.keep_n_checkpoints

    # os.path.join raises TypeError on None, so only build paths that were given.
    VAE_PATH = os.path.join('./outputs/vae_models', args.vae_path) if exists(args.vae_path) else None
    PGVAE_PATH = os.path.join('./outputs/vae_models', args.pgvae_path) if exists(args.pgvae_path) else None
    VQGAN_MODEL_PATH = args.vqgan_model_path
    VQGAN_CONFIG_PATH = args.vqgan_config_path

    MODEL_DIM = args.dim
    TEXT_SEQ_LEN = args.text_seq_len
    DEPTH = args.depth
    HEADS = args.heads
    DIM_HEAD = args.dim_head
    REVERSIBLE = args.reversible
    LOSS_IMG_WEIGHT = args.loss_img_weight
    FF_DROPOUT = args.ff_dropout
    ATTN_DROPOUT = args.attn_dropout
    STABLE = args.stable_softmax
    SHIFT_TOKENS = args.shift_tokens
    ROTARY_EMB = args.rotary_emb
    ATTN_TYPES = tuple(args.attn_types.split(','))
    SHARED_ATTN_IDS = tuple(args.shared_attn_ids.split(',')) if exists(args.shared_attn_ids) else None
    SHARED_FF_IDS = tuple(args.shared_ff_ids.split(',')) if exists(args.shared_ff_ids) else None
    SHARE_INPUT_OUTPUT_EMB = args.share_input_output_emb

    # ---- data augmentation / dataset helpers --------------------------------

    def normalize_and_scale(points):
        """Normalize one (num_points, 3) cloud and multiply it by a random scale.

        Returns the scaled cloud and the 1-element scale tensor (drawn
        uniformly from [0.9, 1.05)) so a paired cloud can be scaled the same way.
        """
        points = normalize_points_torch(points.unsqueeze(0)).squeeze(0)
        scale = points.new(1).uniform_(0.9, 1.05)
        points *= scale
        return points, scale

    def program_sample(sample):
        """Augment field 0 (points) of a ``Synthesis3D`` sample; pass fields 1-4 through."""
        pgm_pc, _ = normalize_and_scale(torch.Tensor(sample[0]))
        return (pgm_pc.numpy(), sample[1], sample[2], sample[3], sample[4])

    def make_text_pts_dataset(path):
        """Build a ``TextPtsDataset`` with this run's shared text settings (no shuffling)."""
        return TextPtsDataset(
            path,
            text_len=TEXT_SEQ_LEN,
            resize_ratio=args.resize_ratio,
            truncate_captions=args.truncate_captions,
            tokenizer=tokenizer,
            shuffle=False,
        )

    class TextPGPC_Dataset(Dataset):
        """Pair ``TextImageDataset`` samples with ``Synthesis3D`` program samples.

        The length follows the program dataset; text samples are reused
        cyclically (``index % len(text dataset)``).
        """

        def __init__(self, path):
            self.ds_textpc = TextImageDataset(
                os.path.join('/home/tiangel/datasets', path),
                text_len=TEXT_SEQ_LEN,
                resize_ratio=args.resize_ratio,
                truncate_captions=args.truncate_captions,
                tokenizer=tokenizer,
                shuffle=False,
            )
            self.ds_pgpc = Synthesis3D('/home/tiangel/datasets/shapeprogram_data/train_shapes_pc.h5', 10)
            # self.ds_pgpc = Synthesis3D('/home/tiangel/datasets/shapeprogram_data/train_shapes_pc_table_40k.h5',10)
            # self.ds_pgpc = Synthesis3D('/home/tiangel/datasets/shapeprogram_data/train_shapes_pc_chair_160k.h5',10)
            self.len1 = len(self.ds_textpc)
            self.len2 = len(self.ds_pgpc)
            self.len = self.len2

        def __getitem__(self, index):
            # Fetch each sub-sample once: re-indexing repeats the loading work
            # and could mix fields from different draws if loading is random.
            text_sample = self.ds_textpc[index % self.len1]
            text_pc, _ = normalize_and_scale(text_sample[1])
            pgm_return = program_sample(self.ds_pgpc[index])
            return (text_sample[0], text_pc), pgm_return

        def __len__(self):
            return self.len

    class TextPGPC_Dataset2(Dataset):
        """Two text/point-cloud datasets paired with two concatenated program datasets.

        Length is the total number of program samples; text samples are
        reused cyclically over the concatenation of both text datasets.
        """

        def __init__(self, path1, path2):
            self.ds_textpc1 = make_text_pts_dataset(path1)
            self.ds_textpc2 = make_text_pts_dataset(path2)
            # pc_chair_80k
            self.ds_pgpc1 = Synthesis3D('/home/tiangel/datasets/shapeprogram_data/train_shapes_pc.h5', 10)
            # self.ds_pgpc1 = Synthesis3D('/home/tiangel/datasets/shapeprogram_data/train_shapes_pc_chair_40k.h5',10)
            self.ds_pgpc2 = Synthesis3D('/home/tiangel/datasets/shapeprogram_data/train_shapes_pc_table_40k.h5', 10)
            self.len11 = len(self.ds_textpc1)
            self.len12 = len(self.ds_textpc2)
            self.len1 = self.len11 + self.len12
            self.len21 = len(self.ds_pgpc1)
            self.len22 = len(self.ds_pgpc2)
            self.len = self.len21 + self.len22

        def __getitem__(self, index):
            id1 = index % self.len1
            if id1 < self.len11:
                text_sample = self.ds_textpc1[id1]
            else:
                text_sample = self.ds_textpc2[id1 - self.len11]

            if index < self.len21:
                pgm_return = program_sample(self.ds_pgpc1[index])
            else:
                pgm_return = program_sample(self.ds_pgpc2[index - self.len21])
            return (text_sample[0], text_sample[1]), pgm_return

        def __len__(self):
            return self.len

    class TextPGPCPP_Dataset2(Dataset):
        """Three text/point-cloud datasets, two program datasets and completion pairs.

        Each item is ``((text, text_pc), program_sample, (partial_pc, full_pc))``.
        The length is the total number of program samples; the text and
        completion sources are indexed cyclically.
        """

        def __init__(self, path1, path2, path3):
            self.ds_textpc1 = make_text_pts_dataset(path1)
            self.ds_textpc2 = make_text_pts_dataset(path2)
            self.ds_textpc3 = make_text_pts_dataset(path3)
            # pc_chair_80k
            self.ds_pgpc1 = Synthesis3D('/home/tiangel/datasets/shapeprogram_data/train_shapes_pc.h5', 10)
            # self.ds_pgpc1 = Synthesis3D('/home/tiangel/datasets/shapeprogram_data/train_shapes_pc_chair_40k.h5',10)
            self.ds_pgpc2 = Synthesis3D('/home/tiangel/datasets/shapeprogram_data/train_shapes_pc_table_40k.h5', 10)
            # The HDF5 file stays open for the dataset's lifetime; items are read lazily.
            self.ds_pp = h5py.File('/home/tiangel/datasets/completion_data.h5', 'r')
            self.ds_pp_inputs = self.ds_pp['inputs']
            self.ds_pp_targets = self.ds_pp['targets']
            self.len3 = self.ds_pp_inputs.shape[0]
            self.len11 = len(self.ds_textpc1)
            self.len12 = len(self.ds_textpc2)
            self.len13 = len(self.ds_textpc3)
            self.len1 = self.len11 + self.len12 + self.len13
            self.len21 = len(self.ds_pgpc1)
            self.len22 = len(self.ds_pgpc2)
            self.len = self.len21 + self.len22

        def __getitem__(self, index):
            id1 = index % self.len1
            if id1 < self.len11:
                text_sample = self.ds_textpc1[id1]
            elif id1 < self.len11 + self.len12:
                text_sample = self.ds_textpc2[id1 - self.len11]
            else:
                text_sample = self.ds_textpc3[id1 - self.len11 - self.len12]
            text, text_pc = text_sample[0], text_sample[1]

            if index < self.len21:
                pgm_return = program_sample(self.ds_pgpc1[index])
            else:
                # `index` stays shifted below, so completion samples for the
                # second program dataset restart from 0 (as in the original code).
                index -= self.len21
                pgm_return = program_sample(self.ds_pgpc2[index])

            id2 = index % self.len3
            # NOTE(review): hard-coded 20 partial inputs per shape; consider
            # self.ds_pp_inputs.shape[1] once the file layout is confirmed.
            input_pp = self.ds_pp_inputs[id2, np.random.randint(20)]
            target_pp = self.ds_pp_targets[id2]
            input_pp, scale = normalize_and_scale(torch.Tensor(input_pp))
            # The target gets the same random scale as its partial input.
            target_pp = normalize_points_torch(torch.Tensor(target_pp).unsqueeze(0)).squeeze(0)
            target_pp *= scale
            return (text, text_pc), pgm_return, (input_pp, target_pp)

        def __len__(self):
            return self.len

    # ds = TextPGPC_Dataset(args.image_text_folder)
    ds = TextPGPCPP_Dataset2('shapeglot', 'abo', 'text2shape')

    if args.gpus > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            ds, shuffle=True, num_replicas=args.gpus, rank=rank)
        dl = DataLoader(ds, BATCH_SIZE, sampler=train_sampler, drop_last=True)
    else:
        train_sampler = None
        # Shuffle like the distributed path; the dataset is a concatenation of
        # sources, so sequential order would train them one after another.
        dl = DataLoader(ds, BATCH_SIZE, shuffle=True, drop_last=True)

    # ---- model ---------------------------------------------------------------

    if not exists(VAE_PATH):
        raise ValueError('--vae_path is required: DALLE_PG_Discrete is built on a pretrained DiscreteVAE')

    vae_path = Path(VAE_PATH)
    if not vae_path.exists():
        raise FileNotFoundError('VAE model file does not exist')
    if vae_path.is_dir():
        raise IsADirectoryError(
            'Cannot load VAE model from directory; please use a '
            'standard *.pt checkpoint. '
            'Currently, merging a DeepSpeed-partitioned VAE into a DALLE '
            'model is not supported.')

    # map_location='cpu' keeps every rank from materialising the checkpoint on
    # the GPU it was saved from; the model is moved to its own device below.
    loaded_obj = torch.load(str(vae_path), map_location='cpu')
    vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']
    vae = DiscreteVAE(**vae_params)
    vae.load_state_dict(weights)
    vae.to(rank if args.gpus > 1 else 0).eval()

    del loaded_obj, weights
    torch.cuda.empty_cache()

    # NOTE: loading a DiscretePGVAE from PGVAE_PATH is currently disabled, so
    # PGVAE_PATH is unused.

    dalle_params = dict(
        num_text_tokens=tokenizer.vocab_size,
        text_seq_len=TEXT_SEQ_LEN,
        dim=MODEL_DIM,
        depth=DEPTH,
        heads=HEADS,
        dim_head=DIM_HEAD,
        reversible=REVERSIBLE,
        loss_img_weight=LOSS_IMG_WEIGHT,
        attn_types=ATTN_TYPES,
        ff_dropout=FF_DROPOUT,
        attn_dropout=ATTN_DROPOUT,
        stable=STABLE,
        shift_tokens=SHIFT_TOKENS,
        rotary_emb=ROTARY_EMB,
        shared_attn_ids=SHARED_ATTN_IDS,
        shared_ff_ids=SHARED_FF_IDS,
        share_input_output_emb=SHARE_INPUT_OUTPUT_EMB,
        inverse=args.inverse,
    )

    dalle = DALLE_PG_Discrete(vae=vae, **dalle_params).train()

    opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer=opt, T_max=EPOCHS * int(len(ds) / BATCH_SIZE / args.gpus))

    if args.resume:
        load_obj = torch.load(
            os.path.join('./outputs/dalle_models', 'dalle' + args.save_name + '-cpu.pt'),
            map_location='cpu')
        weights = load_obj.pop('weights')
        resume_epoch = load_obj.pop('epoch')
        opt_state = load_obj.pop('opt_state')
        sche_state = load_obj.pop('scheduler_state')
        dalle.load_state_dict(weights)
        dalle.cuda()
        dalle.vae.cuda()
        # The model is already on the GPU, so the optimizer state is cast to
        # the parameters' device as it is loaded.
        opt.load_state_dict(opt_state)
        scheduler.load_state_dict(sche_state)
        del weights, load_obj
        torch.cuda.empty_cache()
    else:
        dalle.cuda()
        resume_epoch = 0

    if args.gpus > 1:
        dalle = DistributedDataParallel(
            dalle, device_ids=[rank], broadcast_buffers=False, find_unused_parameters=True
        )

    def save_model(path, epoch=0, gpus=1):
        """Save hparams, model/optimizer/scheduler state and ``epoch`` to ``path``.

        Under DDP the inner module's weights are saved, so keys carry no
        ``module.`` prefix and load directly into an unwrapped model. ``gpus``
        is kept for call compatibility; the wrapper type decides the unwrapping.
        """
        model = dalle.module if isinstance(dalle, DistributedDataParallel) else dalle
        save_obj = {
            'hparams': dalle_params,
            'vae_params': vae_params,
            'epoch': epoch,
            'version': get_pkg_version(),
            'vae_class_name': vae.__class__.__name__,
            'weights': model.state_dict(),
            'opt_state': opt.state_dict(),
            'scheduler_state': (scheduler.state_dict() if scheduler else None),
        }
        torch.save(save_obj, path)

    # ---- training --------------------------------------------------------------

    write_dir = os.path.join('./outputs/dalle_outputs', 'test' + args.save_name)
    if comm.is_main_process():
        os.makedirs(write_dir, exist_ok=True)
    # Only the main process logs to TensorBoard; other ranks pass writer=None.
    writer = SummaryWriter(write_dir) if (comm.is_main_process() and args.tensorboard_flag) else None

    # Keeps the final save well-defined when the loop below runs zero times.
    epoch = resume_epoch
    for epoch in range(resume_epoch, EPOCHS):
        if train_sampler is not None:
            # Without this every epoch reuses the same distributed shuffle order.
            train_sampler.set_epoch(epoch)
        start_time = time.time()
        for i, ((text, pts), pg_data, pp_data) in enumerate(dl):

            text, pts = map(lambda t: t.cuda(), (text, pts))

            loss = dalle(
                text, pts, pg_data, pp_data,
                return_loss=True,
                inverse=args.inverse,
                fixed_pos=args.fixed_pos,
                discrete_type=args.discrete_type,
                completion_only=args.completion_only,
                textshape_only=args.textshape_only,
                shapetext_only=args.shapetext_only,
                writer=writer,
                global_step=epoch * len(dl) + i,
            )

            opt.zero_grad()
            loss.backward()
            opt.step()
            scheduler.step()

            if i % SAVE_EVERY_N_STEPS == 0 and comm.is_main_process():
                print('epoch:%03d, iteration:%04d, lr:%f' % (epoch, i, scheduler.get_last_lr()[0]))
                save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch, gpus=args.gpus)

        elapsed_time_secs = time.time() - start_time
        msg = "1 epoch took: %s secs" % timedelta(seconds=round(elapsed_time_secs))
        print(msg)

        if comm.is_main_process():
            save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch, gpus=args.gpus)

    if comm.is_main_process():
        save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch, gpus=args.gpus)

    if writer is not None:
        writer.close()


def str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` maps every non-empty string, including
    ``"False"`` and ``"0"``, to ``True``. This converter accepts the usual
    true/false spellings (case-insensitive) and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    group = parser.add_mutually_exclusive_group(required=False)

    group.add_argument('--vae_path', type=str,
                       help='path (relative to ./outputs/vae_models) to your trained discrete VAE')

    group.add_argument('--dalle_path', type=str,
                       help='path to your partially trained DALL-E')

    parser.add_argument('--pgvae_path', type=str, default=None,
                        help='path (relative to ./outputs/vae_models) to a trained DiscretePGVAE checkpoint (currently unused)')

    parser.add_argument('--vqgan_model_path', type=str, default=None,
                        help='path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)')

    parser.add_argument('--vqgan_config_path', type=str, default=None,
                        help='path to your trained VQGAN config. This should be a .yaml file. (only valid when taming option is enabled)')

    parser.add_argument('--image_text_folder', type=str, default='DatasetError',
                        help='path to your folder of images and text for learning the DALL-E')

    parser.add_argument('--wds', type=str, default='',
                        help='Comma separated list of WebDataset (1) image and (2) text column names. Must contain 2 values, e.g. img,cap.')

    parser.add_argument('--truncate_captions', dest='truncate_captions', action='store_true',
                        help='Captions passed in which exceed the max token length will be truncated if this is set.')

    parser.add_argument('--random_resize_crop_lower_ratio', dest='resize_ratio', type=float, default=0.75,
                        help='Random resized crop lower ratio')

    parser.add_argument('--chinese', dest='chinese', action='store_true')

    parser.add_argument('--taming', dest='taming', action='store_true')

    parser.add_argument('--hug', dest='hug', action='store_true')

    parser.add_argument('--bpe_path', type=str,
                        help='path to your BPE json file')

    parser.add_argument('--dalle_output_file_name', type=str, default="./outputs/dalle_models/dalle",
                        help='output_file_name')

    parser.add_argument('--fp16', action='store_true',
                        help='(experimental) - Enable DeepSpeed 16 bit precision. Reduces VRAM.')

    parser.add_argument('--amp', action='store_true',
                        help='Apex "O1" automatic mixed precision. More stable than 16 bit precision. Can\'t be used in conjunction with deepspeed zero stages 1-3.')

    parser.add_argument('--wandb_name', default='dalle_train_transformer',
                        help='Name W&B will use when saving results.\ne.g. `--wandb_name "coco2017-full-sparse"`')

    parser.add_argument('--wandb_entity', default=None,
                        help='(optional) Name of W&B team/entity to log to.')

    parser.add_argument('--stable_softmax', dest='stable_softmax', action='store_true',
                        help='Prevent values from becoming too large during softmax. Helps with stability in fp16 and Mixture of Quantization training.')

    parser = distributed_utils.wrap_arg_parser(parser)

    train_group = parser.add_argument_group('Training settings')

    train_group.add_argument('--flops_profiler', dest='flops_profiler', action='store_true', help='Exits after printing detailed flops/runtime analysis of forward/backward')

    train_group.add_argument('--epochs', default=20, type=int, help='Number of epochs')

    train_group.add_argument('--save_every_n_steps', default=100, type=int, help='Save a checkpoint every n steps')

    train_group.add_argument('--keep_n_checkpoints', default=None, type=int, help='(Careful) Deletes old deepspeed checkpoints if there are more than n')

    train_group.add_argument('--batch_size', default=4, type=int, help='Batch size')

    train_group.add_argument('--ga_steps', default=1, type=int, help='Number of steps to accumulate gradients across per each iteration. DeepSpeed only.')

    train_group.add_argument('--learning_rate', default=1e-3, type=float, help='Learning rate')

    train_group.add_argument('--clip_grad_norm', default=0.5, type=float, help='Clip gradient norm')

    train_group.add_argument('--lr_decay', dest='lr_decay', action='store_true')

    model_group = parser.add_argument_group('Model settings')

    model_group.add_argument('--dim', default=512, type=int, help='Model dimension')

    model_group.add_argument('--text_seq_len', default=256, type=int, help='Text sequence length')

    model_group.add_argument('--depth', default=2, type=int, help='Model depth')

    model_group.add_argument('--heads', default=8, type=int, help='Model number of heads')

    model_group.add_argument('--dim_head', default=64, type=int, help='Model head dimension')

    train_group.add_argument('--ff_dropout', default=0.0, type=float, help='Feed forward dropout.')

    train_group.add_argument('--attn_dropout', default=0.0, type=float, help='Attention dropout.')

    model_group.add_argument('--reversible', dest='reversible', action='store_true')

    model_group.add_argument('--loss_img_weight', default=7, type=int, help='Image loss weight')

    model_group.add_argument('--attn_types', default='full', type=str, help='comma separated list of attention types. attention type can be: full or sparse or axial_row or axial_col or conv_like.')

    model_group.add_argument('--shift_tokens', help='Use the shift tokens feature', action='store_true')

    model_group.add_argument('--rotary_emb', help='Use rotary embeddings', action='store_true')

    model_group.add_argument('--shared_attn_ids', default=None, type=str, help='Comma separated list of shared attention layer ids. Default: sharing is disabled')

    model_group.add_argument('--shared_ff_ids', default=None, type=str, help='Comma separated list of shared feed forward layer ids. Default: sharing is disabled')

    model_group.add_argument('--share_input_output_emb', help='Share input and output embeddings', action='store_true')

    model_group.add_argument('--save_name', type=str, default='1', help='Suffix appended to checkpoint and output directory names')

    model_group.add_argument('--port', type=str, default='12358', help='MASTER_PORT used for multi-GPU training')

    model_group.add_argument('--gpus', type=int, default=1, help='Number of GPUs (one training process per GPU)')

    # Boolean options use str2bool: with type=bool, "--flag False" would parse as True.
    model_group.add_argument('--inverse', type=str2bool, default=False, help='inverse feeding')

    model_group.add_argument('--fixed_pos', type=str2bool, default=False, help='(bool) passed to the model as fixed_pos')

    model_group.add_argument('--discrete_type', type=int, default=1, help='passed to the model as discrete_type')

    model_group.add_argument('--resume', type=str2bool, default=False, help='(bool) resume from ./outputs/dalle_models/dalle<save_name>-cpu.pt')

    model_group.add_argument('--dataset', type=str, default='all', help='dataset selector (currently unused)')

    model_group.add_argument('--completion_only', type=str2bool, default=False, help='(bool) passed to the model as completion_only')

    model_group.add_argument('--tensorboard_flag', type=str2bool, default=False, help='(bool) log to TensorBoard under ./outputs/dalle_outputs/test<save_name>')

    model_group.add_argument('--textshape_only', type=str2bool, default=False, help='(bool) passed to the model as textshape_only')

    model_group.add_argument('--shapetext_only', type=str2bool, default=False, help='(bool) passed to the model as shapetext_only')

    args = parser.parse_args()

    if args.gpus == 1:
        # Single process is rank 0 (train() only uses `rank` when gpus > 1).
        train(0, args)
    else:
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = args.port
        args.world_size = args.gpus
        mp.spawn(train, nprocs=args.gpus, args=(args,))